#!/bin/env python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mystat import MYSTAT


#----------
# Parameters
#----------
# All paths are relative to the current working directory.

# APHRODITE annual count of days with 100 mm / 5 days or more per grid
# point (CSV with "year" and "days" columns)
amefl = "./Data/ame100mm.csv"

# Observed time series of anomalous days
datafl = "./Data/anomalousdays_observations.json"

# HiFLOR (model ensemble) time series of anomalous days
datafl2 = "./Data/anomalousdays_HiFLOR.json"

#----------
# Subroutines
#----------
def linreg(X,Y):
    """
    Least-squares straight-line fit y = aa * x + bb.

    Input:
        X: 1-D sequence of x values (e.g. years)
        Y: 1-D sequence of y values, same length as X

    Output:
        X:   the input X, returned unchanged (handy for plotting)
        YY2: numpy array of fitted values aa * X + bb at each X
        aa:  slope of the fitted line (units of Y per unit of X)

    Raises:
        ValueError: if X and Y differ in length, or if X has fewer than
        two distinct values (the slope is then undefined).
    """
    xs = np.asarray(X, dtype=float)
    ys = np.asarray(Y, dtype=float)
    if xs.shape != ys.shape:
        raise ValueError("linreg: X and Y must have the same length")
    if xs.size < 2 or np.all(xs == xs[0]):
        raise ValueError("linreg: need at least two distinct X values")

    # Work with deviations from the means.  With X given as calendar years
    # the textbook form N*Sxx - Sx**2 subtracts two large, nearly equal
    # numbers; centring first keeps the arithmetic well conditioned.
    xm = xs.mean()
    ym = ys.mean()
    dx = xs - xm
    aa = np.dot(dx, ys - ym) / np.dot(dx, dx)
    bb = ym - aa * xm

    # Evaluate on the array copy so plain-list X works too.
    YY2 = aa * xs + bb
    return X, YY2, aa

def mk_test(x, alpha = 0.05):
    """
    Two-sided Mann-Kendall (MK) test for a monotonic trend in the data.

    Input:
        x:     1-D sequence of data, in time order
        alpha: significance level

    Output:
        h: True if a trend is present at level alpha, False otherwise
        p: p-value of the significance test

    The variance of the S statistic includes the standard correction for
    tied values.  If x contains NaN the statistic is undefined: p is NaN
    and h is False.

    Examples
    --------
      >>> x = np.random.rand(100)
      >>> h,p = mk_test(x,0.05)
    """
    # Imported here (after the docstring, so __doc__ is set) because the
    # module does not import scipy at top level.
    from scipy.stats import norm

    x = np.asarray(x)
    n = len(x)

    # S = sum of sign(x[j] - x[k]) over all pairs k < j; the inner sum over
    # j is vectorised.
    s = 0
    for k in range(n - 1):
        s += np.sign(x[k + 1:] - x[k]).sum()

    # Var(S).  Each group of t tied values REDUCES the variance by
    # t(t-1)(2t+5)/18; groups of size 1 contribute 0, so this single
    # expression also covers the no-tie case.
    _, tp = np.unique(x, return_counts=True)
    var_s = (n*(n-1)*(2*n+5) - np.sum(tp*(tp-1)*(2*tp+5)))/18

    # Continuity-corrected normal score.  The final `else` also catches a
    # NaN S (NaN in x), which then propagates to z and p.
    if s > 0:
        z = (s - 1)/np.sqrt(var_s)
    elif s == 0:
        z = 0
    else:
        z = (s + 1)/np.sqrt(var_s)

    # calculate the p_value
    p = 2*(1-norm.cdf(abs(z))) # two tail test
    h = abs(z) > norm.ppf(1-alpha/2)

    return h, p

#----------
# Read Data
#----------
# APHRODITE annual counts (columns "year" and "days")
dfa1 = pd.read_csv(amefl, delimiter=',', header=0)
# Observed anomalous days, indexed by year
alldata = pd.read_json(datafl, typ='series')
# HiFLOR ensemble, indexed by year; ensemble statistics below are taken
# across columns (axis=1), presumably one column per member.
# Read once and keep both the year index and the values.
hiflor = pd.read_json(datafl2)
years2 = hiflor.index
alldata2 = hiflor.to_numpy()


plt.figure(0, figsize=(10,5))
ax1 = plt.subplot(1,1,1)

#----------
# Plot
#----------
mystat = MYSTAT()

# Observed
years = alldata.index
alldata = alldata.to_numpy()
h, pval = mk_test(alldata, alpha=0.05)
ax1.plot(years, alldata, "o-", ms=9, lw=2, color="black", label="Observed (P-value=%.2f)" % (pval), zorder=200)
if pval <= 0.05:
    xx, yy, aa = linreg(years, alldata)
    ax1.plot(xx, yy, "--", color="black", lw=3, zorder=200)
    # `aa` exists only when the trend line was fitted; printing it outside
    # this branch raised NameError for a non-significant trend.
    print("obs trend=", aa)
ax1.set_ylabel("Anomalous days", fontsize=14)
ax1.set_xlabel("Year", fontsize=14)
ax1.tick_params(labelsize=14)

# HiFLOR AllForc: ensemble spread (shaded) and ensemble mean
ensmin = alldata2.min(axis=1)
ensmax = alldata2.max(axis=1)
ensmean = alldata2.mean(axis=1)

h1, pval1 = mk_test(ensmean, alpha=0.05)
ax1.fill_between(years2, ensmin, ensmax, color="red", alpha=0.2)
ax1.plot(years2, ensmean, "o-", ms=9, lw=2, color="red", label="HiFLOR (P-value=%.2f)" % (pval1), zorder=150)
if pval1 <= 0.05:
    xx1, yy1, aa1 = linreg(years2, ensmean)
    ax1.plot(xx1, yy1, "--", color="red", lw=3, zorder=150)

ax1.legend(loc=2, prop={"size":12})
ax1.set_ylim(-2,36)

# APHRODITE 100 mm accumulated precipitation (right-hand axis, gray)
ax2 = ax1.twinx()
label2 = r"Annual days of 100 $\rm mm\ 5 days^{-1}$"
label2 += "\n"
label2 += r"or more per grid point [gray]"

# Separate name so the HiFLOR years (years2) are not overwritten.
years_ame = dfa1["year"].to_numpy()
amedata = dfa1["days"].to_numpy()

ax2.plot(years_ame, amedata, 'o-', ms=8, lw=2, color="gray", label=label2, zorder=100)

ax2.set_ylabel(label2, fontsize=14)
h, pval3 = mk_test(amedata, alpha=0.05)

if pval3 <= 0.05:
    xx, yy, aa = linreg(years_ame, amedata)
    ax2.plot(xx, yy, "--", color="gray", lw=3, zorder=100)

ax2.set_ylim(4,24)

#---Compute Correlation over a common window
syear_ = 1977
eyear_ = 2015

ax2.set_title(r"(a) Anomalous Days (%d-%d)" % (syear_, eyear_), fontsize=16)

idx1_ = np.where((years >= syear_) & (years <= eyear_))
idx2_ = np.where((years_ame >= syear_) & (years_ame <= eyear_))
# Window BOTH series; previously the full amedata was passed, which
# misaligns the pair whenever the APHRODITE record extends beyond the
# window.  NOTE(review): assumes both records are complete annual series
# so the windowed arrays line up year by year — confirm.
cor, pval = mystat.correlation_test(alldata[idx1_], amedata[idx2_], 0.05)

ax2.text(0.6, 0.025, "R(black vs. gray)=%.2f" % (cor), transform=ax2.transAxes, fontsize=16)
ax2.legend(loc=2, prop={"size":12}, bbox_to_anchor=(0.35,1.005))

ax2.tick_params(labelsize=14)
# The x axis is shared with the twin, so this sets ticks for both axes.
ax1.set_xticks(np.arange(1975,2020,5))


plt.tight_layout()
plt.savefig("./Fig4a.png")
